#!/usr/bin/env python3
"""
Phase 4: Synthetic Control Method for Individual CCA Countries

Constructs synthetic counterfactuals for CCA countries that opened their
capital accounts, following Abadie, Diamond & Hainmueller (2010, 2015).

For each treated country:
1. Match on pre-treatment outcomes and covariates to construct a weighted
   combination of control (non-CCA) countries
2. Compare actual post-treatment CA trajectory to synthetic counterfactual
3. Permutation inference: apply SCM to each control country as "placebo"
   and compare treated gap to distribution of placebo gaps

Target countries (selected for clear opening events):
- Mongolia (MNG): opened 1996, non-Soviet, early opener
- Russia (RUS): opened 2000, largest CCA economy
- Azerbaijan (AZE): opened 2002, oil economy
- Belarus (BLR): opened 2007, late opener
- Tajikistan (TJK): opened 2008, poorest CCA
- Georgia (GEO): opened 2012, Rose Revolution, cleanest political event
- Kyrgyz Republic (KGZ): opened 2016, very late opener

Output:
  scm_weights.csv — Donor weights for each target country
  scm_gaps.csv — Actual minus synthetic CA by year
  scm_placebo_gaps.csv — Placebo gaps for permutation inference
  scm_summary.csv — Summary statistics and p-values
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy.optimize import minimize

# Filesystem layout. The project root is a hard-coded absolute path; every
# other directory is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
CAUSAL_DIR = PROJECT_DIR / "causal_identification"
# Directory load_panel() reads causal_panel.csv from.
PROCESSED_DIR = CAUSAL_DIR / "data" / "processed"
# Directory the scm_*.csv result tables are written to.
OUTPUT_DIR = CAUSAL_DIR / "output" / "tables"
# NOTE: import-time side effect -- the output directory (and any missing
# parents) is created as soon as this module is loaded.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes of the CCA countries. get_donor_pool() excludes every one of
# these from the donor pool, whether or not it is an SCM target below.
CCA_COUNTRIES = [
    'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA',
    'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB'
]

# Target countries with opening years: ISO3 code -> treatment year.
# Observations with year < treatment year count as pre-treatment.
TARGETS = {
    'MNG': 1996,
    'RUS': 2000,
    'AZE': 2002,
    'BLR': 2007,
    'TJK': 2008,
    'GEO': 2012,
    'KGZ': 2016,
}

# Matching variables for SCM: panel columns whose pre-treatment means are
# matched between the treated unit and its synthetic control.
# NOTE(review): the meaning of Z_1..Z_3 is defined by the upstream panel
# build, not in this file -- confirm against that step.
MATCH_VARS = ['ca_gdp', 'Z_1', 'Z_2', 'Z_3', 'trade_openness',
              'nfa_gdp_lag', 'fiscal_bal_gdp']


# =====================================================================
# Synthetic Control Implementation
# =====================================================================

class SyntheticControl:
    """
    Synthetic Control Method following Abadie et al. (2010).

    Finds donor weights W that minimize the distance between the treated
    unit's pre-treatment characteristics and the weighted average of
    donor units' characteristics.

    min_W ||X1 - X0 * W||² subject to W >= 0, sum(W) = 1

    Each matching variable is divided by its standard deviation across the
    treated unit and the donors, and all variables then enter the distance
    with equal importance (no V matrix is estimated).

    After a successful ``fit`` the following attributes are populated:
    ``weights`` (array, one entry per donor), ``donors`` (list of ISO3
    codes aligned with ``weights``), ``pre_fit_error`` (RMSE between the
    treated and synthetic matching-variable means, in original units),
    ``gaps`` (DataFrame: year, actual, synthetic, gap, pre_treatment),
    ``pre_rmspe``, ``post_rmspe`` and ``rmspe_ratio``.
    """

    def __init__(self):
        self.weights = None
        self.donors = None
        self.treated_iso = None
        self.treatment_year = None
        # pre_fit / post_gap are never populated by this class; they are
        # kept only so existing attribute access does not break.
        self.pre_fit = None
        self.post_gap = None
        self._reset_fit_state()

    def _reset_fit_state(self):
        """Clear fit results so a failed (re)fit never exposes stale values."""
        self.weights = None
        self.donors = None
        self.pre_fit_error = None
        self.gaps = None
        self.pre_rmspe = np.nan
        self.post_rmspe = np.nan
        self.rmspe_ratio = np.nan

    def fit(self, panel, treated_iso, treatment_year, donor_isos,
            outcome_var='ca_gdp', match_vars=None, min_pre_years=3):
        """
        Fit synthetic control.

        Parameters
        ----------
        panel : DataFrame with iso3, year, outcome, covariates
        treated_iso : str, ISO3 of treated country
        treatment_year : int, year of treatment
        donor_isos : list, ISO3 codes of potential donor countries
        outcome_var : str, outcome variable
        match_vars : list, variables to match on (pre-treatment averages);
            defaults to ``[outcome_var]``
        min_pre_years : int, minimum pre-treatment years required

        Returns
        -------
        self on success. None when the fit cannot be carried out: fewer
        than 2 pre-treatment rows for the treated unit, a matching variable
        with no pre-treatment data for the treated unit, fewer than 2 usable
        donors, or non-finite optimizer output. On failure ``weights`` is
        left as None.
        """
        self._reset_fit_state()
        self.treated_iso = treated_iso
        self.treatment_year = treatment_year

        if match_vars is None:
            match_vars = [outcome_var]
        match_vars = list(match_vars)

        # Pre-treatment window: strictly before the treatment year.
        pre = panel[panel['year'] < treatment_year]

        # Treated unit pre-treatment characteristics. A short window only
        # warns; the fit is abandoned only below 2 rows.
        treated_pre = pre[pre['iso3'] == treated_iso]
        if len(treated_pre) < min_pre_years:
            print(f"    WARNING: {treated_iso} has only {len(treated_pre)} "
                  f"pre-treatment years (need {min_pre_years})")
            if len(treated_pre) < 2:
                return None

        X1 = treated_pre[match_vars].mean().to_numpy(dtype=float)
        # A NaN target would make the objective NaN for every weight vector,
        # so the optimizer output would be meaningless. Fail explicitly.
        if np.any(np.isnan(X1)):
            missing = [v for v, x in zip(match_vars, X1) if np.isnan(x)]
            print(f"    WARNING: {treated_iso} has no pre-treatment data "
                  f"for {missing}; cannot fit")
            return None

        # Donor pre-treatment row counts and means, computed in one pass.
        donor_pre = pre[pre['iso3'].isin(donor_isos)]
        by_donor = donor_pre.groupby('iso3')
        n_pre_rows = by_donor.size()
        pre_means = by_donor[match_vars].mean()

        valid_donors = []
        X0_cols = []
        for d in donor_isos:  # keep the caller's donor order
            if d not in n_pre_rows.index or n_pre_rows[d] < min_pre_years:
                continue
            d_means = pre_means.loc[d].to_numpy(dtype=float)
            if np.any(np.isnan(d_means)):
                continue
            valid_donors.append(d)
            X0_cols.append(d_means)

        if len(valid_donors) < 5:
            print(f"    WARNING: Only {len(valid_donors)} valid donors")
            if len(valid_donors) < 2:
                return None

        X0 = np.column_stack(X0_cols)  # K x J (K matching vars, J donors)

        # Put every matching variable on a unit-variance scale so no single
        # variable dominates the distance; constant variables are unscaled.
        scale = np.std(np.column_stack([X1, X0]), axis=1)
        scale[scale == 0] = 1.0
        X1_norm = X1 / scale
        X0_norm = X0 / scale[:, np.newaxis]

        J = len(valid_donors)

        def objective(w):
            diff = X1_norm - X0_norm @ w
            return float(diff @ diff)

        def gradient(w):
            # Analytic gradient of the squared distance above.
            return -2.0 * (X0_norm.T @ (X1_norm - X0_norm @ w))

        # Constraints: weights sum to 1, all within [0, 1]
        constraints = {
            'type': 'eq',
            'fun': lambda w: np.sum(w) - 1.0,
            'jac': lambda w: np.ones_like(w),
        }
        bounds = [(0, 1)] * J
        w0 = np.ones(J) / J  # equal weights initial

        result = minimize(objective, w0, method='SLSQP', jac=gradient,
                          bounds=bounds, constraints=constraints,
                          options={'maxiter': 1000, 'ftol': 1e-12})

        weights = np.asarray(result.x, dtype=float)
        if not np.all(np.isfinite(weights)):
            print(f"    WARNING: optimizer returned non-finite weights "
                  f"for {treated_iso}")
            return None
        if not result.success:
            print(f"    WARNING: weight optimization did not converge "
                  f"for {treated_iso} ({result.message})")

        # SLSQP can leave the simplex by rounding error; project back.
        weights = np.clip(weights, 0.0, 1.0)
        total = weights.sum()
        if total <= 0:
            return None
        weights = weights / total

        self.weights = weights
        self.donors = valid_donors

        # Pre-treatment fit on the matching variables (original units)
        self.pre_fit_error = np.sqrt(np.mean((X1 - X0 @ weights) ** 2))

        # Actual vs synthetic outcome for every year the treated unit has
        self.gaps = self._outcome_gaps(panel, outcome_var)

        # Pre/post RMSPE
        is_pre = self.gaps['pre_treatment'].to_numpy(dtype=bool)
        gap_values = self.gaps['gap'].to_numpy(dtype=float)
        self.pre_rmspe = self._rmspe(gap_values[is_pre])
        self.post_rmspe = self._rmspe(gap_values[~is_pre])
        self.rmspe_ratio = (self.post_rmspe / self.pre_rmspe
                            if self.pre_rmspe > 0 else np.nan)

        return self

    def _outcome_gaps(self, panel, outcome_var):
        """Build the year-by-year actual / synthetic / gap table.

        For each year in which the treated unit has a non-missing outcome,
        the synthetic value is the weight-averaged donor outcome over the
        donors observed that year, renormalized by the weight present.
        Years where the observed donors carry 50% or less of the total
        weight are dropped. Always returns a DataFrame with the five gap
        columns, even when no year qualifies.
        """
        units = [self.treated_iso, *self.donors]
        subset = panel.loc[panel['iso3'].isin(units),
                           ['iso3', 'year', outcome_var]]
        # year x country matrix of outcomes. Duplicate (year, iso3) rows are
        # averaged, matching how the pre-treatment means treat them.
        wide = (subset.groupby(['year', 'iso3'])[outcome_var]
                .mean()
                .unstack('iso3')
                .sort_index())

        actual = wide.reindex(columns=[self.treated_iso])[self.treated_iso]
        actual = actual.dropna()
        donor_vals = wide.reindex(index=actual.index,
                                  columns=self.donors).to_numpy(dtype=float)

        present = ~np.isnan(donor_vals)
        weight_present = present @ self.weights
        synth_sum = np.where(present, donor_vals, 0.0) @ self.weights

        keep = weight_present > 0.5  # at least 50% of weight present
        years = actual.index.to_numpy()[keep]
        actual_vals = actual.to_numpy(dtype=float)[keep]
        synthetic = synth_sum[keep] / weight_present[keep]  # renormalize

        return pd.DataFrame({
            'year': years,
            'actual': actual_vals,
            'synthetic': synthetic,
            'gap': actual_vals - synthetic,
            'pre_treatment': years < self.treatment_year,
        })

    @staticmethod
    def _rmspe(gap_values):
        """Root mean squared gap; NaN when there are no observations."""
        if gap_values.size == 0:
            return np.nan
        return np.sqrt(np.mean(gap_values ** 2))

    def top_donors(self, n=5):
        """Return up to n (iso3, weight) pairs, largest weight first.

        Donors with weight of 0.01 or less are omitted; returns an empty
        list when the model has not been fitted.
        """
        if self.weights is None:
            return []
        order = np.argsort(-self.weights)[:n]
        return [(self.donors[i], self.weights[i])
                for i in order if self.weights[i] > 0.01]


# =====================================================================
# Permutation inference
# =====================================================================

def permutation_test(panel, treated_iso, treatment_year, donor_isos,
                     match_vars, n_placebos=None):
    """
    Permutation inference: apply SCM to each donor as if treated.

    The p-value is the rank of the treated unit's post/pre RMSPE ratio
    among all units (treated + placebos): the share of units whose ratio
    is at least as large as the treated unit's.

    Parameters
    ----------
    panel : DataFrame with iso3, year, outcome, covariates
    treated_iso : str, ISO3 of the treated country
    treatment_year : int, treatment year, also assigned to every placebo
    donor_isos : list, ISO3 codes of the donor pool
    match_vars : list, matching variables passed to SyntheticControl.fit
    n_placebos : int or None, number of placebo units; the FIRST
        n_placebos entries of donor_isos are used (list order, not a
        random sample). None means all donors.

    Returns
    -------
    (p_value, placebo_gaps). p_value is None when the treated fit fails
    and NaN when the treated RMSPE ratio is undefined (so no ranking is
    possible). placebo_gaps is a list of dicts with keys iso3, year, gap,
    pre_treatment for the placebos that entered the ranking.
    """
    if n_placebos is None:
        n_placebos = len(donor_isos)

    # Fit treated SCM
    sc_treated = SyntheticControl()
    fitted = sc_treated.fit(panel, treated_iso, treatment_year, donor_isos,
                            match_vars=match_vars)

    if fitted is None or sc_treated.weights is None:
        return None, []

    treated_ratio = sc_treated.rmspe_ratio
    print(f"    Treated {treated_iso}: post/pre RMSPE ratio = {treated_ratio:.4f}")

    # Placebo SCMs
    placebo_ratios = []
    placebo_gaps = []

    donors_to_test = donor_isos[:n_placebos]
    # Placebos fitting far worse than the treated unit before treatment
    # (pre-RMSPE more than 5x the treated unit's) are left out of the ranking.
    max_pre_rmspe = 5 * sc_treated.pre_rmspe

    for i, placebo_iso in enumerate(donors_to_test):
        # Use all other donors (excluding the placebo unit) as donor pool
        placebo_donors = [d for d in donor_isos if d != placebo_iso]

        sc_placebo = SyntheticControl()
        # Placebos use a looser pre-period requirement (2 years) than the
        # treated fit's default.
        result = sc_placebo.fit(panel, placebo_iso, treatment_year,
                                placebo_donors, match_vars=match_vars,
                                min_pre_years=2)
        if result is None:
            continue

        usable = (not np.isnan(sc_placebo.rmspe_ratio)
                  and sc_placebo.pre_rmspe > 0
                  and sc_placebo.pre_rmspe <= max_pre_rmspe)
        if usable:
            placebo_ratios.append(sc_placebo.rmspe_ratio)

            if sc_placebo.gaps is not None:
                records = sc_placebo.gaps[
                    ['year', 'gap', 'pre_treatment']].to_dict('records')
                placebo_gaps.extend(
                    {'iso3': placebo_iso, **rec} for rec in records
                )

        if (i + 1) % 20 == 0:
            print(f"      Completed {i+1}/{len(donors_to_test)} placebos")

    # With an undefined treated ratio every `r >= treated_ratio` comparison
    # is False, which would yield rank 0 and a spurious p-value of 0.
    if pd.isna(treated_ratio):
        print("    Permutation p-value: nan "
              "(treated RMSPE ratio is undefined)")
        return np.nan, placebo_gaps

    # P-value: fraction of units with ratio >= treated ratio
    all_ratios = [treated_ratio] + placebo_ratios
    rank = sum(1 for r in all_ratios if r >= treated_ratio)
    p_value = rank / len(all_ratios)

    print(f"    Permutation p-value: {p_value:.4f} "
          f"(rank {rank}/{len(all_ratios)})")

    return p_value, placebo_gaps


# =====================================================================
# Main analysis
# =====================================================================

def load_panel():
    """Read causal_panel.csv from PROCESSED_DIR, restricted to 1992-2024.

    The window is inclusive at both ends and deliberately wide so that
    early openers still have pre-treatment years to match on. Returns an
    independent copy of the filtered rows.
    """
    raw = pd.read_csv(PROCESSED_DIR / "causal_panel.csv", low_memory=False)
    in_window = raw['year'].between(1992, 2024)
    return raw[in_window].copy()


def get_donor_pool(panel):
    """Return ISO3 codes of non-CCA countries usable as SCM donors.

    A country qualifies when it is not listed in CCA_COUNTRIES and has at
    least 10 non-missing ``ca_gdp`` observations in ``panel``. Codes come
    back in the groupby's sorted order.
    """
    is_cca = panel['iso3'].isin(CCA_COUNTRIES)
    candidates = panel[~is_cca]

    # Number of years with an observed current account, per country
    ca_obs = candidates.groupby('iso3')['ca_gdp'].count()
    enough_data = ca_obs[ca_obs >= 10]

    return enough_data.index.tolist()


def run_scm_for_country(panel, treated_iso, treatment_year, donor_pool,
                        n_placebos=50):
    """Fit, report and test the synthetic control for a single country.

    Prints the fit diagnostics, the largest donor weights, the yearly
    actual-vs-synthetic table and the mean post-treatment gap, then runs
    the placebo permutation test using MATCH_VARS.

    Returns None when the fit fails; otherwise a dict with keys
    'sc' (the fitted SyntheticControl), 'p_value' and 'placebo_gaps'.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"SCM: {treated_iso} (treatment year: {treatment_year})")
    print(f"{banner}")

    # Fit SCM
    sc = SyntheticControl()
    fitted = sc.fit(panel, treated_iso, treatment_year, donor_pool,
                    match_vars=MATCH_VARS)
    if fitted is None:
        print(f"  SCM failed for {treated_iso}")
        return None

    # Fit diagnostics
    print(f"\n  Pre-treatment RMSPE: {sc.pre_rmspe:.4f}")
    print(f"  Post-treatment RMSPE: {sc.post_rmspe:.4f}")
    print(f"  Ratio (post/pre): {sc.rmspe_ratio:.4f}")

    # Largest donors
    leading = sc.top_donors(8)
    print("\n  Top donor weights:")
    for donor_iso, weight in leading:
        print(f"    {donor_iso}: {weight:.4f} ({100*weight:.1f}%)")

    leading_share = sum(weight for _, weight in leading)
    print(f"    (top {len(leading)} donors account for "
          f"{100*leading_share:.1f}% of weight)")

    # Yearly comparison table
    print("\n  Year-by-year gaps (actual - synthetic CA/GDP):")
    print(f"  {'Year':>6s} {'Actual':>8s} {'Synth':>8s} {'Gap':>8s} {'Period':>12s}")
    print("  " + "-" * 45)

    for rec in sc.gaps.itertuples(index=False):
        period = "PRE" if rec.pre_treatment else "POST"
        print(f"  {int(rec.year):>6d} {rec.actual:>8.2f} {rec.synthetic:>8.2f} "
              f"{rec.gap:>+8.2f} {period:>12s}")

    # Mean gap once treatment has started
    post_gaps = sc.gaps.loc[~sc.gaps['pre_treatment'], 'gap']
    if not post_gaps.empty:
        print(f"\n  Average post-treatment gap: {post_gaps.mean():+.2f}pp")

    # Placebo-based inference
    print(f"\n  Running permutation test ({n_placebos} placebos)...")
    p_value, placebo_gaps = permutation_test(
        panel, treated_iso, treatment_year, donor_pool,
        match_vars=MATCH_VARS, n_placebos=n_placebos
    )

    return {'sc': sc, 'p_value': p_value, 'placebo_gaps': placebo_gaps}


if __name__ == '__main__':
    # Driver: run the SCM for every entry in TARGETS, then write the
    # weight, gap, placebo-gap and summary tables to OUTPUT_DIR.
    rule = "=" * 70
    print(rule)
    print("PHASE 4: SYNTHETIC CONTROL METHOD")
    print(rule)

    panel = load_panel()
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")

    donor_pool = get_donor_pool(panel)
    print(f"Donor pool: {len(donor_pool)} countries")

    # Accumulators, one list of row-dicts per output table
    all_weights = []
    all_gaps = []
    all_placebo_gaps = []
    summary_rows = []

    for treated_iso, treatment_year in TARGETS.items():
        result = run_scm_for_country(
            panel, treated_iso, treatment_year, donor_pool,
            n_placebos=60  # balance speed vs precision
        )
        if result is None:
            continue

        sc = result['sc']

        # Donor weights; negligible weights (<= 0.001) are not stored
        all_weights.extend(
            {'treated': treated_iso, 'donor': donor, 'weight': weight}
            for donor, weight in zip(sc.donors, sc.weights)
            if weight > 0.001
        )

        # Yearly actual / synthetic / gap rows
        all_gaps.extend(
            {
                'treated': treated_iso,
                'treatment_year': treatment_year,
                'year': rec.year,
                'actual': rec.actual,
                'synthetic': rec.synthetic,
                'gap': rec.gap,
                'pre_treatment': rec.pre_treatment,
            }
            for rec in sc.gaps.itertuples(index=False)
        )

        # Placebo gaps, tagged with the country they were run for
        for placebo_row in result['placebo_gaps']:
            placebo_row['treated_country'] = treated_iso
        all_placebo_gaps.extend(result['placebo_gaps'])

        # One summary row per treated country
        post_gaps = sc.gaps.loc[~sc.gaps['pre_treatment'], 'gap']
        best = sc.top_donors(1)
        summary_rows.append({
            'country': treated_iso,
            'treatment_year': treatment_year,
            'pre_rmspe': sc.pre_rmspe,
            'post_rmspe': sc.post_rmspe,
            'rmspe_ratio': sc.rmspe_ratio,
            'avg_post_gap': np.nan if post_gaps.empty else post_gaps.mean(),
            'n_post_years': len(post_gaps),
            'p_value': result['p_value'],
            'n_donors_used': int((sc.weights > 0.01).sum()),
            'top_donor': best[0][0] if best else '',
            'top_donor_weight': best[0][1] if best else 0,
        })

    # Write each non-empty table
    tables = [
        (all_weights, "scm_weights.csv"),
        (all_gaps, "scm_gaps.csv"),
        (all_placebo_gaps, "scm_placebo_gaps.csv"),
    ]
    for records, filename in tables:
        if records:
            pd.DataFrame(records).to_csv(OUTPUT_DIR / filename, index=False)

    if summary_rows:
        summary_df = pd.DataFrame(summary_rows)
        summary_df.to_csv(OUTPUT_DIR / "scm_summary.csv", index=False)

        print("\n\n" + rule)
        print("SCM SUMMARY")
        print(rule)
        print(f"\n{'Country':>8s} {'Treat yr':>9s} {'Pre RMSPE':>10s} {'Post RMSPE':>11s} "
              f"{'Ratio':>7s} {'Avg gap':>8s} {'p-value':>8s}")
        print("-" * 70)

        for entry in summary_df.itertuples(index=False):
            # Conventional significance stars at the 1% / 5% / 10% levels
            p = entry.p_value
            if p < 0.01:
                sig = '***'
            elif p < 0.05:
                sig = '**'
            elif p < 0.1:
                sig = '*'
            else:
                sig = ''
            print(f"{entry.country:>8s} {int(entry.treatment_year):>9d} "
                  f"{entry.pre_rmspe:>10.3f} {entry.post_rmspe:>11.3f} "
                  f"{entry.rmspe_ratio:>7.2f} {entry.avg_post_gap:>+8.2f} "
                  f"{entry.p_value:>8.4f} {sig}")

    print("\n" + rule)
    print("PHASE 4 COMPLETE")
    print(rule)
    print("\nOutput files:")
    for csv_path in sorted(OUTPUT_DIR.glob("scm_*.csv")):
        print(f"  {csv_path.name}")
